{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate a pretrained model on the validation dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torchvision.transforms as transforms\n",
    "from munch import Munch\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "\n",
    "import datasets\n",
    "import models\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the YAML config to evaluate; keep exactly one config_path assignment uncommented.\n",
    "config_path = 'pretrained/floating_kinect1_object/config.yml'\n",
    "#config_path = 'pretrained/floating_kinect1_mask/config.yml'\n",
    "#config_path = 'pretrained/floating_kinect1_mask_with_occlusion/config.yml'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the selected YAML file into a Munch so entries can be read as attributes (cfg.arch, cfg.data, ...).\n",
    "with open(config_path) as config_file:\n",
    "    cfg = Munch.fromYAML(config_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> loaded checkpoint 'pretrained/floating_kinect1_object/checkpoint_00002100.pth.tar' (epoch 2100)\n"
     ]
    }
   ],
   "source": [
    "# Build the network described by cfg.arch, wrap it in DataParallel and move it to the GPU.\n",
    "model = torch.nn.DataParallel(models.Model(cfg.arch)).cuda()\n",
    "cudnn.benchmark = True\n",
    "\n",
    "# Restore the weights stored in the checkpoint named by the config, then switch to eval mode.\n",
    "checkpoint = torch.load(cfg.training.resume)\n",
    "model.load_state_dict(checkpoint['state_dict'])\n",
    "model.eval()\n",
    "print(f\"=> loaded checkpoint '{cfg.training.resume}' (epoch {checkpoint['epoch']})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ToTensor is the only transform handed to the dataset.\n",
    "transform = transforms.ToTensor()\n",
    "val_dataset = datasets.RenderedPoseDataset(\n",
    "    cfg.data.root,\n",
    "    cfg.data.objects,\n",
    "    cfg.data.val_subset_num,\n",
    "    transform,\n",
    ")\n",
    "\n",
    "# One sample per batch, loaded in the main process (num_workers=0).\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    dataset=val_dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=True,\n",
    "    num_workers=0,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_batch(model, input, target, object_index):\n",
    "    \"\"\"Run the model on one batch and return the batch-mean pose errors.\n",
    "\n",
    "    The first three columns of ``target`` are compared with the predicted\n",
    "    position; the remaining columns are compared with the predicted\n",
    "    orientation via ``utils.batch_rotation_angle``.\n",
    "\n",
    "    Args:\n",
    "        model: callable invoked as ``model(input, object_index)`` and\n",
    "            returning a ``(position, orientation)`` pair.\n",
    "        input: batch passed to the model unchanged.\n",
    "        target: 2-D tensor of ground-truth poses; moved to the GPU here.\n",
    "        object_index: tensor passed to the model; moved to the GPU here.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(position_error, orientation_error)`` of batch means.\n",
    "    \"\"\"\n",
    "    object_index = object_index.cuda(non_blocking=True)\n",
    "    target = target.cuda(non_blocking=True)\n",
    "    target_position, target_orientation = target[:, :3], target[:, 3:]\n",
    "\n",
    "    position, orientation = model(input, object_index)\n",
    "\n",
    "    # Per-sample Euclidean distance between target and predicted position.\n",
    "    squared_distance = (target_position - position).pow(2).sum(dim=1)\n",
    "    position_error = squared_distance.sqrt()\n",
    "\n",
    "    # 180 / pi scaling: presumably batch_rotation_angle returns radians -- TODO confirm.\n",
    "    rotation_angle = utils.batch_rotation_angle(target_orientation, orientation)\n",
    "    orientation_error = 180.0 / np.pi * rotation_angle\n",
    "\n",
    "    return position_error.mean(), orientation_error.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89acf562285c4fef893488f402fd7f18",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=3200), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "overall position error: 1.12 meters\n",
      "overall orientation error: 8.93 degrees\n"
     ]
    }
   ],
   "source": [
    "# Running averages of the per-batch mean errors over the whole validation loader.\n",
    "position_errors = utils.AverageMeter()\n",
    "orientation_errors = utils.AverageMeter()\n",
    "with torch.no_grad():  # evaluation only, no gradients are needed\n",
    "    for input, target, object_index in tqdm(val_loader):\n",
    "        batch_size = input.size(0)\n",
    "        position_error, orientation_error = forward_batch(model, input, target, object_index)\n",
    "        # The batch size is passed as the second argument (presumably the averaging weight).\n",
    "        position_errors.update(position_error, batch_size)\n",
    "        orientation_errors.update(orientation_error, batch_size)\n",
    "\n",
    "# The position average is multiplied by 100 before printing, so the printed value is in\n",
    "# centimeters, not meters (assumes the dataset targets are expressed in meters).\n",
    "print('overall position error: {:.2f} cm'.format(100 * position_errors.avg))\n",
    "print('overall orientation error: {:.2f} degrees'.format(orientation_errors.avg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
